package conf

import (
	"fmt"
	"math"
	"sync"

	"github.com/zeast/logs"
)

// limitConf is the per-user connection-limit configuration. Every value is a
// percentage that is applied to a node's maximum connection count (the budget
// is computed as maxConnNum * per / 100).
type limitConf struct {
	// DefaultPer is the percentage used when neither UserNodeLimit nor
	// UserLimit has an entry for the user.
	DefaultPer    int32
	// UserLimit maps a user to the percentage applied on every node.
	UserLimit     map[string]int32
	// UserNodeLimit maps user -> node alias -> percentage; it takes precedence
	// over UserLimit.
	UserNodeLimit map[string]map[string]int32
}

// maxTokenNum caps the token budget of a single nodeLimit and is also the
// buffer size of every token channel.
const maxTokenNum = 1 << 13

// limit tracks connection tokens per (user, node alias) pair. The embedded
// mutex guards every field; all methods in this file lock it.
type limit struct {
	sync.Mutex
	// conf is the active percentage configuration (see Update).
	conf     limitConf
	// nodeConf maps a node alias to its maximum connection count
	// (see ChangeNodeConf).
	nodeConf map[string]int32
	// token maps user -> node alias -> that pair's token budget
	// (populated lazily by Get).
	token    map[string]map[string]*nodeLimit
}

// nodeLimit is the token budget of one user on one node alias.
//
//   - maxConnNum: the node's maximum connection count.
//   - per: the percentage of maxConnNum granted to the user.
//   - num: the number of tokens currently minted for this budget.
//   - offset: how many returned tokens Put must swallow instead of recycling
//     into ch; adjust raises it when the budget is lowered.
//   - ch: the channel that holds the idle tokens.
type nodeLimit struct {
	maxConnNum, per, num, offset int32
	ch                           chan struct{}
}

// Limit is the package-wide connection limiter. Its maps are allocated up
// front so that Get and ChangeNodeConf can write into them directly.
var Limit = &limit{
	token:    map[string]map[string]*nodeLimit{},
	nodeConf: map[string]int32{},
}

// adjust recomputes the token budget of one (user, alias) pair as
// Round(maxConnNum * per / 100), clamped to [0, maxTokenNum], and brings the
// token channel ch in line with it. The caller must hold the mutex of the
// limit that owns nl.
//
// Raising the budget first cancels any reduction still pending in offset and
// only then mints new tokens into ch. Lowering it first reclaims idle tokens
// from ch so the lower budget applies immediately; tokens that cannot be
// reclaimed right now are added to offset, and Put swallows that many returns.
// Sends into ch never block, so adjust cannot hang while the mutex is held.
func (nl *nodeLimit) adjust(user, alias string) {
	// Clamp in the float domain: converting an out-of-range float64 to int32
	// is implementation-defined in Go.
	want := Round(float64(nl.maxConnNum) * float64(nl.per) / float64(100))

	var newNum int32
	switch {
	case want > maxTokenNum:
		logs.Errorf("用户到 MySQL 的限制数据不能超过%d. user:%s, alias:%s, maxConnNum:%d, per:%d, num:%d", maxTokenNum, user, alias, nl.maxConnNum, nl.per, nl.num)
		newNum = maxTokenNum
	case want < 0:
		// A negative per or maxConnNum would otherwise drive num below zero
		// and corrupt the offset bookkeeping.
		newNum = 0
	default:
		newNum = int32(want)
	}

	logs.Debugf("修改 user到 MySQL 的连接限制. user:%s, alias:%s, old:%d, new:%d", user, alias, nl.num, newNum)
	switch {
	case newNum > nl.num:
		grow := newNum - nl.num

		// Undo pending reductions before minting: each unit of offset is a
		// token still in circulation that Put was about to discard.
		cancelled := nl.offset
		if cancelled > grow {
			cancelled = grow
		}
		nl.offset -= cancelled
		grow -= cancelled

	mint:
		for ; grow > 0; grow-- {
			select {
			case nl.ch <- struct{}{}:
			default:
				// ch is already full (e.g. a caller Put more than it got); a
				// blocking send here would hang forever under the limit mutex.
				logs.Errorf("用户到 MySQL 的令牌管道已满, 丢弃剩余令牌. user:%s, alias:%s, remaining:%d", user, alias, grow)
				break mint
			}
		}

	case newNum < nl.num:
		shrink := nl.num - newNum

		// Reclaim idle tokens so the lower budget takes effect at once.
	reclaim:
		for ; shrink > 0; shrink-- {
			select {
			case <-nl.ch:
			default:
				break reclaim
			}
		}

		// The remaining excess tokens are checked out; Put discards them as
		// they come back. Accumulate, because an earlier reduction may still
		// be pending.
		nl.offset += shrink
	}

	nl.num = newNum
}

// ChangeNodeConf records maxConnNum as the maximum connection count of the
// node named alias. The first call for an alias only stores the value. Later
// calls also copy it into every user's nodeLimit for that alias and re-run
// adjust, so the token budgets follow the new maximum.
func (l *limit) ChangeNodeConf(alias string, maxConnNum int32) {
	l.Lock()
	defer l.Unlock()

	logs.Debugf("修改 node 的最大连接数. %s, %d", alias, maxConnNum)

	_, known := l.nodeConf[alias]
	l.nodeConf[alias] = maxConnNum
	if !known {
		// Get only creates a nodeLimit for aliases already present in
		// nodeConf, so no user can have one for this alias yet.
		return
	}

	// Each user's limits are keyed by alias: look it up directly instead of
	// scanning every node the user has touched.
	for user, nodes := range l.token {
		if nl, ok := nodes[alias]; ok {
			nl.maxConnNum = maxConnNum
			nl.adjust(user, alias)
		}
	}
}

// Get returns the token channel that limits user's connections to the node
// named alias. On the first call for a pair, it builds a nodeLimit from the
// node's configured maximum and the user's percentage, and fills the channel
// via adjust. It returns an error if no maximum has been registered for alias
// (see ChangeNodeConf). In that case l.token is left untouched.
//
// NOTE(review): callers presumably receive a token from the channel before
// using a connection and hand it back with Put — confirm against call sites.
func (l *limit) Get(user string, alias string) (chan struct{}, error) {
	l.Lock()
	defer l.Unlock()

	// Indexing through a missing user yields a nil inner map, which is safe
	// to read.
	if nl, ok := l.token[user][alias]; ok {
		return nl.ch, nil
	}

	// Validate the alias before touching l.token, so an unknown alias does
	// not leave an empty per-user map behind.
	maxConn, ok := l.nodeConf[alias]
	if !ok {
		return nil, fmt.Errorf("试图从一个不存在 node 配置的 限制表 中获取数据. user:%s, alias:%s", user, alias)
	}

	nodes, ok := l.token[user]
	if !ok {
		nodes = make(map[string]*nodeLimit)
		l.token[user] = nodes
	}

	nl := &nodeLimit{
		maxConnNum: maxConn,
		per:        l.per(user, alias),
		ch:         make(chan struct{}, maxTokenNum),
	}
	nl.adjust(user, alias)
	nodes[alias] = nl

	return nl.ch, nil
}

// per resolves the percentage of a node's connections granted to user on
// alias, most specific setting first: the UserNodeLimit entry, then the
// UserLimit entry, then DefaultPer. It does not lock; Get and Update call it
// while holding l's mutex.
func (l *limit) per(user string, alias string) int32 {
	if hit, pct := l.userNodeLimit(user, alias); hit {
		return pct
	}
	if hit, pct := l.userLimit(user); hit {
		return pct
	}
	return l.conf.DefaultPer
}

// Put hands one token back to the (user, alias) budget. While a budget
// reduction is pending (offset > 0), the token is discarded instead of being
// recycled. Unknown pairs and a full channel are logged and otherwise ignored.
func (l *limit) Put(user string, alias string) {
	l.Lock()
	defer l.Unlock()

	nodes, found := l.token[user]
	if !found {
		logs.Warnf("往一个不存在 user 的限制中返回数据,%s, %s", user, alias)
		return
	}

	nl, found := nodes[alias]
	if !found {
		logs.Warnf("往一个不存在 alias 的限制中返回数据,%s, %s", user, alias)
		return
	}

	// Swallow this token to honor a lowered budget.
	if nl.offset > 0 {
		nl.offset--
		return
	}

	// Non-blocking: never stall while holding the mutex.
	select {
	case nl.ch <- struct{}{}:
	default:
		logs.Errorf("用户返回连接超出管道限制, %s, %s", user, alias)
	}
}

// userNodeLimit reports whether conf.UserNodeLimit has a percentage for this
// exact (user, alias) pair, and returns it (0 when absent).
func (l *limit) userNodeLimit(user string, alias string) (bool, int32) {
	// A missing user yields a nil inner map, and reading a nil map is safe,
	// so one chained comma-ok lookup covers both levels.
	pct, ok := l.conf.UserNodeLimit[user][alias]
	return ok, pct
}

// userLimit reports whether conf.UserLimit has a percentage for user, and
// returns it (0 when absent).
func (l *limit) userLimit(user string) (bool, int32) {
	pct, ok := l.conf.UserLimit[user]
	return ok, pct
}

// Update installs lc as the active limit configuration, then recomputes the
// percentage and token budget of every existing (user, alias) nodeLimit.
//
// lc's maps are stored by reference, not copied, so the caller should not
// mutate them afterwards.
func (l *limit) Update(lc limitConf) {
	l.Lock()
	// defer, like every other method, so a panic in adjust cannot leave the
	// mutex locked.
	defer l.Unlock()

	l.conf = lc
	for user, nodes := range l.token {
		for alias, nl := range nodes {
			nl.per = l.per(user, alias)
			nl.adjust(user, alias)
		}
	}
}

// IEEE 754 binary64 layout constants used by Round. They mirror the unexported
// constants of the same names in the Go standard library's math/bits.go
// (https://github.com/golang/go/blob/master/src/math/bits.go).
const (
	// bias is the exponent bias: a stored exponent of bias encodes 2^0.
	bias = 1023
	// mask selects the 11-bit exponent field once it has been shifted down.
	mask = 1<<11 - 1
	// shift is the number of fraction bits: 64 bits minus 11 exponent bits
	// minus the sign bit.
	shift = 52
)

// Round returns the nearest integer, rounding half away from zero.
// It is a copy of the standard library's math.Round, which is only available
// natively since Go 1.10; presumably it is kept here so the package still
// builds with older toolchains.
//
// Special cases are:
//	Round(±0) = ±0
//	Round(±Inf) = ±Inf
//	Round(NaN) = NaN
func Round(x float64) float64 {
	// Round is a faster implementation of:
	//
	// func Round(x float64) float64 {
	//   t := Trunc(x)
	//   if Abs(x-t) >= 0.5 {
	//     return t + Copysign(1, x)
	//   }
	//   return t
	// }
	//
	// It works directly on the IEEE 754 bit pattern: signMask is the sign bit,
	// fracMask covers the 52 fraction bits, half is the fraction bit worth 0.5
	// when the unbiased exponent is 0, and one is the bit pattern of 1.0.
	const (
		signMask = 1 << 63
		fracMask = 1<<shift - 1
		half     = 1 << (shift - 1)
		one      = bias << shift
	)

	bits := math.Float64bits(x)
	// e is the biased exponent field of x.
	e := uint(bits>>shift) & mask
	if e < bias {
		// Round abs(x) < 1 including denormals.
		bits &= signMask // +-0
		// e == bias-1 means 0.5 <= abs(x) < 1, which rounds away from zero.
		if e == bias-1 {
			bits |= one // +-1
		}
	} else if e < bias+shift {
		// Round any abs(x) >= 1 containing a fractional component [0,1).
		//
		// Numbers with larger exponents are returned unchanged since they
		// must be either an integer, infinity, or NaN.
		e -= bias
		// Add 0.5 at this exponent (a carry into the exponent field is
		// correct, e.g. 1.5 -> 2), then clear the bits that now hold only
		// the fractional part.
		bits += half >> e
		bits &^= fracMask >> e
	}
	return math.Float64frombits(bits)
}
